{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Deploy Quantized Models using Torch-TensorRT\n\nThis example demonstrates how to deploy a VGG16 model quantized to FP8 with post-training quantization (PTQ), using the Dynamo frontend of Torch-TensorRT.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Imports and Model Definition\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n\nimport modelopt.torch.quantization as mtq\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch_tensorrt as torchtrt\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nfrom modelopt.torch.quantization.utils import export_torch_mode\n\n\nclass VGG(nn.Module):\n    \"\"\"VGG-style network: conv/batch-norm/ReLU feature stack plus an MLP classifier.\n\n    Args:\n        layer_spec: Iterable whose entries are either the string \"pool\"\n            (adds a 2x2 max-pool) or the output channel count of a 3x3 conv.\n        num_classes: Output width of the last linear layer.\n        init_weights: If true, run ``_initialize_weights`` after construction.\n    \"\"\"\n\n    def __init__(self, layer_spec, num_classes=1000, init_weights=False):\n        super().__init__()\n\n        stack = []\n        channels = 3  # the first convolution is built for 3 input channels\n        for entry in layer_spec:\n            if entry != \"pool\":\n                stack.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))\n                stack.append(nn.BatchNorm2d(entry))\n                stack.append(nn.ReLU())\n                channels = entry\n            else:\n                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))\n\n        self.features = nn.Sequential(*stack)\n        # Pool every feature map to 1x1, matching the 512-input first linear layer.\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        hidden = 4096\n        self.classifier = nn.Sequential(\n            nn.Linear(512 * 1 * 1, hidden),\n            nn.ReLU(),\n            nn.Dropout(),\n            nn.Linear(hidden, hidden),\n            nn.ReLU(),\n            nn.Dropout(),\n            nn.Linear(hidden, num_classes),\n        )\n        if init_weights:\n            self._initialize_weights()\n\n    def _initialize_weights(self):\n        \"\"\"Re-initialise conv, batch-norm and linear parameters in place.\"\"\"\n        for module in self.modules():\n            if isinstance(module, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    module.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n                )\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n                continue\n            if isinstance(module, nn.BatchNorm2d):\n                nn.init.constant_(module.weight, 1)\n                nn.init.constant_(module.bias, 0)\n                continue\n            if isinstance(module, nn.Linear):\n                nn.init.normal_(module.weight, 0, 0.01)\n                nn.init.constant_(module.bias, 0)\n\n    def forward(self, x):\n        \"\"\"Apply the feature stack, average pooling, flatten and the classifier.\"\"\"\n        pooled = self.avgpool(self.features(x))\n        return self.classifier(torch.flatten(pooled, 1))\n\n\ndef vgg16(num_classes=1000, init_weights=False):\n    \"\"\"Build a ``VGG`` with five conv stages (2, 2, 3, 3, 3 convs), each followed by a pool.\"\"\"\n    stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))\n    layout = []\n    for width, repeats in stages:\n        layout += [width] * repeats + [\"pool\"]\n    return VGG(layout, num_classes, init_weights)\n\n\n# Command-line options: checkpoint path and the batch size used for PTQ and testing.\nPARSER = argparse.ArgumentParser(\n    description=\"Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16\"\n)\nPARSER.add_argument(\n    \"--ckpt\",\n    required=True,\n    type=str,\n    help=\"Path to the pre-trained checkpoint\",\n)\nPARSER.add_argument(\n    \"--batch-size\",\n    type=int,\n    default=128,\n    help=\"Batch size for tuning the model with PTQ and FP8\",\n)\n\nargs = PARSER.parse_args()\n\n# 10 output classes; weights are loaded from the checkpoint in a later cell.\nmodel = vgg16(num_classes=10, init_weights=False).cuda()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load the pre-trained model weights\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Deserialize onto the CPU so a checkpoint saved from another GPU index still\n# loads; `load_state_dict` below copies the values into the CUDA model.\nckpt = torch.load(args.ckpt, map_location=\"cpu\")\nweights = ckpt[\"model_state_dict\"]\n\n# Some checkpoints prefix every key with `module.` (presumably because they were\n# saved from a multi-GPU wrapper such as `nn.DataParallel`). Strip the prefix only\n# from keys that actually carry it, rather than slicing seven characters off\n# every key whenever the current machine happens to have several GPUs.\nDATA_PARALLEL_PREFIX = \"module.\"\nweights = {\n    (k[len(DATA_PARALLEL_PREFIX) :] if k.startswith(DATA_PARALLEL_PREFIX) else k): v\n    for k, v in weights.items()\n}\n\nmodel.load_state_dict(weights)\n# Don't forget to set the model to evaluation mode!\nmodel.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load training dataset and define loss function for PTQ\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Augmentation and normalization applied to every calibration image.\ntrain_transform = transforms.Compose(\n    [\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n    ]\n)\ntraining_dataset = datasets.CIFAR10(\n    root=\"./data\", train=True, download=True, transform=train_transform\n)\n\nloader_options = dict(\n    batch_size=args.batch_size,\n    shuffle=True,\n    num_workers=2,\n    drop_last=True,\n)\ntraining_dataloader = torch.utils.data.DataLoader(training_dataset, **loader_options)\n\n# Pull one batch of images; the labels of this batch are not needed.\ndata = iter(training_dataloader)\nimages, _ = next(data)\n\n# Loss used when reporting calibration and test metrics.\ncrit = nn.CrossEntropyLoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define Calibration Loop for quantization\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def calibrate_loop(model):\n    \"\"\"Run `model` over the training dataloader for calibration and print loss/accuracy.\n\n    Args:\n        model: Module to call on each batch; batches are moved to CUDA first.\n    \"\"\"\n    total = 0\n    correct = 0\n    loss_sum = 0.0\n    # Only forward passes are needed here, so skip autograd bookkeeping.\n    with torch.no_grad():\n        for data, labels in training_dataloader:\n            data, labels = data.cuda(), labels.cuda(non_blocking=True)\n            out = model(data)\n            batch = labels.size(0)\n            # `crit` returns the mean over the batch; weight it by the batch size so\n            # dividing by `total` below yields a true per-sample mean. `.item()`\n            # keeps the running sum a plain float instead of a growing tensor chain.\n            loss_sum += crit(out, labels).item() * batch\n            preds = torch.max(out, 1)[1]\n            total += batch\n            correct += (preds == labels).sum().item()\n\n    if total == 0:\n        # An empty loader (e.g. batch size larger than the dataset with\n        # drop_last=True) would otherwise fail with a ZeroDivisionError.\n        print(\"PTQ calibration saw no batches; check the dataset and batch size.\")\n        return\n\n    print(\"PTQ Loss: {:.5f} Acc: {:.2f}%\".format(loss_sum / total, 100 * correct / total))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Tune the pre-trained model with FP8 and PTQ\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Use ModelOpt's default FP8 recipe.\nquant_cfg = mtq.FP8_DEFAULT_CFG\n# Post-training quantization: modules are swapped for quantized versions in place,\n# with `calibrate_loop` supplied as the forward loop.\nmtq.quantize(\n    model,\n    quant_cfg,\n    forward_loop=calibrate_loop,\n)\n# From here on `model` carries FP8 quantize/dequantize (QDQ) nodes."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inference\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Load the testing dataset\ntesting_dataset = datasets.CIFAR10(\n    root=\"./data\",\n    train=False,\n    download=True,\n    transform=transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n        ]\n    ),\n)\n\ntesting_dataloader = torch.utils.data.DataLoader(\n    testing_dataset,\n    batch_size=args.batch_size,\n    shuffle=False,\n    num_workers=2,\n    drop_last=True,\n)  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`\n\nwith torch.no_grad():\n    with export_torch_mode():\n        # Compile the model with Torch-TensorRT Dynamo backend\n        input_tensor = images.cuda()\n        exp_program = torch.export.export(model, (input_tensor,))\n        trt_model = torchtrt.dynamo.compile(\n            exp_program,\n            inputs=[input_tensor],\n            enabled_precisions={torch.float8_e4m3fn},\n            min_block_size=1,\n            debug=False,\n        )\n        # You can also use torch compile path to compile the model with Torch-TensorRT:\n        # trt_model = torch.compile(model, backend=\"tensorrt\")\n\n        # Inference compiled Torch-TensorRT model over the testing dataset\n        total = 0\n        correct = 0\n        loss = 0.0\n        class_probs = []\n        class_preds = []\n        for data, labels in testing_dataloader:\n            data, labels = data.cuda(), labels.cuda(non_blocking=True)\n            out = trt_model(data)\n            batch = labels.size(0)\n            # `crit` averages over the batch; weight by the batch size so that\n            # `loss / total` below is a true per-sample mean.\n            loss += crit(out, labels).item() * batch\n            preds = torch.max(out, 1)[1]\n            # Softmax over the class dimension for the whole batch at once.\n            class_probs.append(F.softmax(out, dim=1))\n            class_preds.append(preds)\n            total += batch\n            correct += (preds == labels).sum().item()\n\n        test_probs = torch.cat(class_probs)\n        test_preds = torch.cat(class_preds)\n        test_loss = loss / total\n        test_acc = correct / total\n        print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}